# Faraday Penetration Test IDE
# Copyright (C) 2016  Infobyte LLC (http://www.infobytesec.com/)
# See the file 'doc/LICENSE' for the license information
from builtins import str, bytes
from io import TextIOWrapper

import threading
import logging
import csv

from flask import Blueprint, request, abort, jsonify, make_response
from flask_classful import route
from filteralchemy import (
    FilterSet,
    operators,
)
from flask_wtf.csrf import validate_csrf
from marshmallow import fields, ValidationError, Schema, post_load
from marshmallow.validate import OneOf
import wtforms


from faraday.server.api.base import (
    AutoSchema,
    FilterAlchemyMixin,
    FilterSetMeta,
    PaginatedMixin,
    ReadWriteView,
)

from faraday.server.schemas import (
    PrimaryKeyRelatedField,
    SeverityField,
    SelfNestedField,
    FaradayCustomField,
)

from faraday.server.models import (
    db,
    Vulnerability,
    VulnerabilityTemplate,
)

# Flask blueprint on which VulnerabilityTemplateView is registered at the
# bottom of this module.
vulnerability_template_api = Blueprint('vulnerability_template_api', __name__)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class ImpactSchema(Schema):
    """Schema for the four boolean impact flags of a vulnerability template.

    Each public key (``accountability``, ``availability``,
    ``confidentiality``, ``integrity``) is mapped to the correspondingly
    named ``impact_*`` attribute through the ``attribute`` argument, and is
    declared with ``default=False``.
    """
    accountability = fields.Boolean(attribute='impact_accountability', default=False)
    availability = fields.Boolean(attribute='impact_availability', default=False)
    confidentiality = fields.Boolean(attribute='impact_confidentiality', default=False)
    integrity = fields.Boolean(attribute='impact_integrity', default=False)


class VulnerabilityTemplateSchema(AutoSchema):
    """Schema used to dump and load ``VulnerabilityTemplate`` objects.

    Several keys are aliases kept for API compatibility: ``_id``/``id`` both
    read the ``id`` attribute, ``desc`` mirrors ``description`` and ``refs``
    mirrors ``references``; all of those aliases are dump-only.
    """
    _id = fields.Integer(dump_only=True, attribute='id')
    id = fields.Integer(dump_only=True, attribute='id')
    _rev = fields.String(default='', dump_only=True)
    # Deprecated field, the legacy data is added to refs on import.
    cwe = fields.String(dump_only=True, default='')
    # The API key "exploitation" is stored in the "severity" attribute.
    exploitation = SeverityField(attribute='severity', required=True)
    # Dumped as a comma separated string; loaded from a string or a list.
    references = fields.Method('get_references', deserialize='load_references', required=True)
    refs = fields.List(fields.String(), dump_only=True, attribute='references')
    desc = fields.String(dump_only=True, attribute='description')
    data = fields.String(attribute='data')
    impact = SelfNestedField(ImpactSchema())
    easeofresolution = fields.String(
        attribute='ease_of_resolution',
        validate=OneOf(Vulnerability.EASE_OF_RESOLUTIONS),
        allow_none=True)
    policyviolations = fields.List(fields.String,
                                   attribute='policy_violations')
    creator = PrimaryKeyRelatedField('username', dump_only=True, attribute='creator')
    create_at = fields.DateTime(attribute='create_date',
                                dump_only=True)

    # Here we use vulnerability instead of vulnerability_template to avoid duplicate row
    # in the custom_fields_schema table.
    # All validation will be against vulnerability table.
    customfields = FaradayCustomField(table_name='vulnerability', attribute='custom_fields')
    external_id = fields.String(allow_none=True)

    class Meta:
        model = VulnerabilityTemplate
        fields = ('id', '_id', '_rev', 'cwe', 'description', 'desc',
                  'exploitation', 'name', 'references', 'refs', 'resolution',
                  'impact', 'easeofresolution', 'policyviolations', 'data',
                  'customfields', 'external_id', 'creator', 'create_at')

    def get_references(self, obj):
        """Return the names of ``obj``'s reference templates joined by ``', '``."""
        return ', '.join(
            ref_tmpl.name for ref_tmpl in obj.reference_template_instances)

    def load_references(self, value):
        """Normalize the incoming ``references`` value into a list of names.

        :param value: a list of names, or a comma separated string (``bytes``
            are decoded as UTF-8 first). Names coming from a string are
            stripped of surrounding whitespace; list items are kept as given.
        :returns: the list of reference names (``[]`` for an empty string).
        :raises ValidationError: if ``value`` is neither a string nor a list,
            if any item is not a string, or if any name is empty.
        """
        if isinstance(value, bytes):
            value = value.decode('utf-8')
        if isinstance(value, list):
            references = value
        elif isinstance(value, str):
            if not value:
                # Required because "".split(",") == [""]
                return []
            references = [ref.strip() for ref in value.split(',')]
        else:
            raise ValidationError('references must be either a string '
                                  'or a list')
        # Reject non-string items explicitly; otherwise len() below would
        # blow up with a TypeError on e.g. integers instead of reporting a
        # validation problem.
        if not all(isinstance(ref, (str, bytes)) for ref in references):
            raise ValidationError('references must contain only strings')
        if any(len(ref) == 0 for ref in references):
            raise ValidationError('Empty name detected in reference')
        return references

    @post_load
    def post_load_impact(self, data):
        """Flatten the loaded ``impact`` mapping into the top level of ``data``."""
        # Unflatten impact (move data[impact][*] to data[*])
        impact = data.pop('impact', None)
        if impact:
            data.update(impact)
        return data


class VulnerabilityTemplateFilterSet(FilterSet):
    """Filter set allowing equality filtering of templates by ``severity``."""
    class Meta(FilterSetMeta):
        model = VulnerabilityTemplate  # It has all the fields
        # Must be a real one-element tuple: without the trailing comma the
        # parentheses are just grouping and ``fields`` would be the plain
        # string 'severity'.
        fields = ('severity',)
        operators = (operators.Equal,)


# Module-level lock held by VulnerabilityTemplateView.post while it delegates
# to the parent class, so those calls run one at a time within this process.
lock = threading.Lock()


class VulnerabilityTemplateView(PaginatedMixin,
                                FilterAlchemyMixin,
                                ReadWriteView):
    """Read/write API view for vulnerability templates, with CSV bulk import."""
    route_base = 'vulnerability_template'
    model_class = VulnerabilityTemplate
    schema_class = VulnerabilityTemplateSchema
    filterset_class = VulnerabilityTemplateFilterSet
    get_joinedloads = [VulnerabilityTemplate.creator]

    def _envelope_list(self, objects, pagination_metadata=None):
        """Wrap serialized templates in a ``rows`` / ``total_rows`` envelope.

        :param objects: sized sequence of serialized templates (each must
            expose an ``'_id'`` key).
        :param pagination_metadata: accepted for interface compatibility;
            not used here.
        :returns: dict with one row per template and the number of rows.
        """
        rows = [
            {
                'id': template['_id'],
                'key': template['_id'],
                'value': {'rev': ''},
                'doc': template,
            }
            for template in objects
        ]
        return {
            'rows': rows,
            'total_rows': len(objects)
        }

    def post(self, **kwargs):
        """Create a template, holding the module-level lock during the call."""
        with lock:
            return super(VulnerabilityTemplateView, self).post(**kwargs)

    def _get_schema_instance(self, route_kwargs, **kwargs):
        """Return the schema instance built by the parent class, unchanged."""
        return super(VulnerabilityTemplateView, self)._get_schema_instance(
            route_kwargs, **kwargs)

    def _create_templates_from_rows(self, rows):
        """Create one template per CSV row and count successes and failures.

        :param rows: iterable of dicts keyed by CSV header (a
            ``csv.DictReader``).
        :returns: tuple ``(created_count, error_count)``.

        Failures on individual rows are logged and counted, never raised;
        errors raised while *reading* a row propagate to the caller.
        """
        created_count = 0
        error_count = 0
        # NOTE(review): ``.data`` below assumes marshmallow 2's load result;
        # whether invalid rows raise here depends on the schema's strictness
        # settings, which are not visible in this module - confirm.
        schema = self.schema_class()
        for vuln_dict in rows:
            try:
                # Fill in the keys that are not CSV columns. "resolution" IS
                # a CSV column, so keep the value read from the file and only
                # fall back to "" when the row has none.
                other_fields = {
                    'customfields': [],
                    'create_at': '',
                    'creator': '',
                    'data': u'',
                    'desc': '',
                    'easeofresolution': None,
                    'id': '',
                    'impact': {
                        'accountability': False,
                        'availability': False,
                        'confidentiality': False,
                        'integrity': False
                    },
                    'policyviolations': [],
                    "refs": [],
                    "resolution": vuln_dict.get('resolution') or "",
                    "type": "vulnerability_template",
                }
                vuln_dict.update(other_fields)
                vuln_schema = schema.load(vuln_dict)
                super(VulnerabilityTemplateView, self)._perform_create(vuln_schema.data)
                db.session.commit()
            except Exception as e:
                logger.error("Error creating vuln (%s)", e)
                # Discard the failed unit of work so the session stays usable
                # for the remaining rows.
                db.session.rollback()
                error_count += 1
            else:
                created_count += 1
        return created_count, error_count

    @route('/bulk_create/', methods=['POST'])
    def bulk_create(self):
        """Create vulnerability templates from an uploaded CSV file.

        The multipart form must carry a valid ``csrf_token`` and a ``file``
        whose header row is exactly the set of expected column names.

        :returns: JSON response with ``vulns_created`` and
            ``vulns_with_errors`` counters (HTTP 200).

        Aborts with 403 on CSRF failure, and with 400 when the file is
        missing, its headers do not match, or it cannot be parsed.
        """
        try:
            validate_csrf(request.form.get('csrf_token'))
        except wtforms.ValidationError:
            abort(403)
        logger.info("Create vulns template from CSV")
        if 'file' not in request.files:
            abort(400, "Missing File in request")
        vulns_file = request.files['file']
        FILE_HEADERS = {'cwe', 'name', 'description', 'resolution', 'exploitation', 'references'}
        vulns_created_count = 0
        vulns_with_errors_count = 0
        try:
            io_wrapper = TextIOWrapper(vulns_file, encoding=request.content_encoding or "utf8")

            vulns_reader = csv.DictReader(io_wrapper, skipinitialspace=True)
            # ``fieldnames`` is None for an empty file; treat that as "no
            # headers" instead of crashing on set(None).
            headers_ok = set(vulns_reader.fieldnames or ()) == FILE_HEADERS
            if headers_ok:
                vulns_created_count, vulns_with_errors_count = (
                    self._create_templates_from_rows(vulns_reader))
        except Exception as e:
            logger.error("Error parsing vulns CSV (%s)", e)
            abort(400, "Error parsing vulns CSV (%s)" % e)
        # Reported outside the try block so this abort is not caught and
        # re-wrapped as a parsing error by the handler above.
        if not headers_ok:
            logger.error("Missing Required headers in CSV (%s)", FILE_HEADERS)
            abort(400, "Missing Required headers in CSV (%s)" % FILE_HEADERS)
        return make_response(
            jsonify(vulns_created=vulns_created_count,
                    vulns_with_errors=vulns_with_errors_count),
            200)

# Register the view on this module's blueprint.
VulnerabilityTemplateView.register(vulnerability_template_api)
